{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/milvus-io/bootcamp/blob/master/bootcamp/tutorials/quickstart/contextual_retrieval_with_milvus.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>   <a href=\"https://github.com/milvus-io/bootcamp/blob/master/bootcamp/tutorials/quickstart/contextual_retrieval_with_milvus.ipynb\" target=\"_blank\">\n",
    "    <img src=\"https://img.shields.io/badge/View%20on%20GitHub-555555?style=flat&logo=github&logoColor=white\" alt=\"GitHub Repository\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Contextual Retrieval with Milvus\n",
    "![image](https://raw.githubusercontent.com/milvus-io/bootcamp/refs/heads/master/images/contextual_retrieval_with_milvus.png)\n",
    "[Contextual Retrieval](https://www.anthropic.com/news/contextual-retrieval) is an advanced retrieval method proposed by Anthropic to address the issue of semantic isolation of chunks, which arises in current Retrieval-Augmented Generation (RAG) solutions. In the current practical RAG paradigm, documents are divided into several chunks, and a vector database is used to search for the query, retrieving the most relevant chunks. An LLM then responds to the query using these retrieved chunks. However, this chunking process can result in the loss of contextual information, making it difficult for the retriever to determine relevance.\n",
    "\n",
    "Contextual Retrieval improves traditional retrieval systems by adding relevant context to each document chunk before embedding or indexing, boosting accuracy and reducing retrieval errors. Combined with techniques like hybrid retrieval and reranking, it enhances Retrieval-Augmented Generation (RAG) systems, especially for large knowledge bases. Additionally, it offers a cost-effective solution when paired with prompt caching, significantly reducing latency and operational costs, with contextualized chunks costing approximately $1.02 per million document tokens. This makes it a scalable and efficient approach for handling large knowledge bases. Anthropic’s solution shows two insightful aspects:\n",
    "- `Document Enhancement`: Query rewriting is a crucial technique in modern information retrieval, often using auxiliary information to make the query more informative. Similarly, to achieve better performance in RAG, preprocessing documents with an LLM (e.g., cleaning the data source, complementing lost information, summarizing, etc.) before indexing can significantly improve the chances of retrieving relevant documents. In other words, this preprocessing step helps bring the documents closer to the queries in terms of relevance.\n",
    "- `Low-Cost Processing by Caching Long Context`: One common concern when using LLMs to process documents is the cost. The KVCache is a popular solution that allows reuse of intermediate results for the same preceding context. While most hosted LLM vendors make this feature transparent to user, Anthropic gives users control over the caching process. When a cache hit occurs, most computations can be saved (this is common when the long context remains the same, but the instruction for each query changes). For more details, click [here](https://www.anthropic.com/news/prompt-caching).\n",
    "\n",
    "In this notebook, we will demonstrate how to perform contextual retrieval using Milvus with an LLM, combining dense-sparse hybrid retrieval and a reranker to create a progressively more powerful retrieval system. The data and experimental setup are based on the [contextual retrieval](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/contextual-embeddings/guide.ipynb).\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparation\n",
    "### Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install \"pymilvus[model]\"\n",
    "!pip install tqdm\n",
    "!pip install anthropic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
 If you are using Google Colab">
    "> If you are using Google Colab, you may need to **restart the runtime** to enable the newly installed dependencies (click on the \"Runtime\" menu at the top of the screen, and select \"Restart session\" from the dropdown menu)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will need API keys from Cohere, Voyage, and Anthropic to run the code."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download Data\n",
    "The following command will download the example data used in original Anthropic [demo](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/contextual-embeddings/guide.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/codebase_chunks.json\n",
    "!wget https://raw.githubusercontent.com/anthropics/anthropic-cookbook/refs/heads/main/skills/contextual-embeddings/data/evaluation_set.jsonl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Retriever"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This class is designed to be flexible, allowing you to choose between different retrieval modes based on your needs. By specifying options in the initialization method, you can determine whether to use contextual retrieval, hybrid search (combining dense and sparse retrieval methods), or a reranker for enhanced results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pymilvus.model.dense import VoyageEmbeddingFunction\n",
    "from pymilvus.model.hybrid import BGEM3EmbeddingFunction\n",
    "from pymilvus.model.reranker import CohereRerankFunction\n",
    "\n",
    "from typing import List, Dict, Any\n",
    "from typing import Callable\n",
    "from pymilvus import (\n",
    "    MilvusClient,\n",
    "    DataType,\n",
    "    AnnSearchRequest,\n",
    "    RRFRanker,\n",
    ")\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import anthropic\n",
    "\n",
    "\n",
    "class MilvusContextualRetriever:\n",
    "    def __init__(\n",
    "        self,\n",
    "        uri=\"milvus.db\",\n",
    "        collection_name=\"contexual_bgem3\",\n",
    "        dense_embedding_function=None,\n",
    "        use_sparse=False,\n",
    "        sparse_embedding_function=None,\n",
    "        use_contextualize_embedding=False,\n",
    "        anthropic_client=None,\n",
    "        use_reranker=False,\n",
    "        rerank_function=None,\n",
    "    ):\n",
    "        self.collection_name = collection_name\n",
    "\n",
    "        # For Milvus-lite, uri is a local path like \"./milvus.db\"\n",
    "        # For Milvus standalone service, uri is like \"http://localhost:19530\"\n",
    "        # For Zilliz Clond, please set `uri` and `token`, which correspond to the [Public Endpoint and API key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#cluster-details) in Zilliz Cloud.\n",
    "        self.client = MilvusClient(uri)\n",
    "\n",
    "        self.embedding_function = dense_embedding_function\n",
    "\n",
    "        self.use_sparse = use_sparse\n",
    "        self.sparse_embedding_function = None\n",
    "\n",
    "        self.use_contextualize_embedding = use_contextualize_embedding\n",
    "        self.anthropic_client = anthropic_client\n",
    "\n",
    "        self.use_reranker = use_reranker\n",
    "        self.rerank_function = rerank_function\n",
    "\n",
    "        if use_sparse is True and sparse_embedding_function:\n",
    "            self.sparse_embedding_function = sparse_embedding_function\n",
    "        elif sparse_embedding_function is False:\n",
    "            raise ValueError(\n",
    "                \"Sparse embedding function cannot be None if use_sparse is False\"\n",
    "            )\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "    def build_collection(self):\n",
    "        schema = self.client.create_schema(\n",
    "            auto_id=True,\n",
    "            enable_dynamic_field=True,\n",
    "        )\n",
    "        schema.add_field(field_name=\"pk\", datatype=DataType.INT64, is_primary=True)\n",
    "        schema.add_field(\n",
    "            field_name=\"dense_vector\",\n",
    "            datatype=DataType.FLOAT_VECTOR,\n",
    "            dim=self.embedding_function.dim,\n",
    "        )\n",
    "        if self.use_sparse is True:\n",
    "            schema.add_field(\n",
    "                field_name=\"sparse_vector\", datatype=DataType.SPARSE_FLOAT_VECTOR\n",
    "            )\n",
    "\n",
    "        index_params = self.client.prepare_index_params()\n",
    "        index_params.add_index(\n",
    "            field_name=\"dense_vector\", index_type=\"FLAT\", metric_type=\"IP\"\n",
    "        )\n",
    "        if self.use_sparse is True:\n",
    "            index_params.add_index(\n",
    "                field_name=\"sparse_vector\",\n",
    "                index_type=\"SPARSE_INVERTED_INDEX\",\n",
    "                metric_type=\"IP\",\n",
    "            )\n",
    "\n",
    "        self.client.create_collection(\n",
    "            collection_name=self.collection_name,\n",
    "            schema=schema,\n",
    "            index_params=index_params,\n",
    "            enable_dynamic_field=True,\n",
    "        )\n",
    "\n",
    "    def insert_data(self, chunk, metadata):\n",
    "        dense_vec = self.embedding_function([chunk])[0]\n",
    "        if self.use_sparse is True:\n",
    "            sparse_result = self.sparse_embedding_function.encode_documents([chunk])\n",
    "            if type(sparse_result) == dict:\n",
    "                sparse_vec = sparse_result[\"sparse\"][[0]]\n",
    "            else:\n",
    "                sparse_vec = sparse_result[[0]]\n",
    "            self.client.insert(\n",
    "                collection_name=self.collection_name,\n",
    "                data={\n",
    "                    \"dense_vector\": dense_vec,\n",
    "                    \"sparse_vector\": sparse_vec,\n",
    "                    **metadata,\n",
    "                },\n",
    "            )\n",
    "        else:\n",
    "            self.client.insert(\n",
    "                collection_name=self.collection_name,\n",
    "                data={\"dense_vector\": dense_vec, **metadata},\n",
    "            )\n",
    "\n",
    "    def insert_contextualized_data(self, doc, chunk, metadata):\n",
    "        contextualized_text, usage = self.situate_context(doc, chunk)\n",
    "        metadata[\"context\"] = contextualized_text\n",
    "        text_to_embed = f\"{chunk}\\n\\n{contextualized_text}\"\n",
    "        dense_vec = self.embedding_function([text_to_embed])[0]\n",
    "        if self.use_sparse is True:\n",
    "            sparse_vec = self.sparse_embedding_function.encode_documents(\n",
    "                [text_to_embed]\n",
    "            )[\"sparse\"][[0]]\n",
    "            self.client.insert(\n",
    "                collection_name=self.collection_name,\n",
    "                data={\n",
    "                    \"dense_vector\": dense_vec,\n",
    "                    \"sparse_vector\": sparse_vec,\n",
    "                    **metadata,\n",
    "                },\n",
    "            )\n",
    "        else:\n",
    "            self.client.insert(\n",
    "                collection_name=self.collection_name,\n",
    "                data={\"dense_vector\": dense_vec, **metadata},\n",
    "            )\n",
    "\n",
    "    def situate_context(self, doc: str, chunk: str):\n",
    "        DOCUMENT_CONTEXT_PROMPT = \"\"\"\n",
    "        <document>\n",
    "        {doc_content}\n",
    "        </document>\n",
    "        \"\"\"\n",
    "\n",
    "        CHUNK_CONTEXT_PROMPT = \"\"\"\n",
    "        Here is the chunk we want to situate within the whole document\n",
    "        <chunk>\n",
    "        {chunk_content}\n",
    "        </chunk>\n",
    "\n",
    "        Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.\n",
    "        Answer only with the succinct context and nothing else.\n",
    "        \"\"\"\n",
    "\n",
    "        response = self.anthropic_client.beta.prompt_caching.messages.create(\n",
    "            model=\"claude-3-haiku-20240307\",\n",
    "            max_tokens=1000,\n",
    "            temperature=0.0,\n",
    "            messages=[\n",
    "                {\n",
    "                    \"role\": \"user\",\n",
    "                    \"content\": [\n",
    "                        {\n",
    "                            \"type\": \"text\",\n",
    "                            \"text\": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),\n",
    "                            \"cache_control\": {\n",
    "                                \"type\": \"ephemeral\"\n",
    "                            },  # we will make use of prompt caching for the full documents\n",
    "                        },\n",
    "                        {\n",
    "                            \"type\": \"text\",\n",
    "                            \"text\": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),\n",
    "                        },\n",
    "                    ],\n",
    "                },\n",
    "            ],\n",
    "            extra_headers={\"anthropic-beta\": \"prompt-caching-2024-07-31\"},\n",
    "        )\n",
    "        return response.content[0].text, response.usage\n",
    "\n",
    "    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:\n",
    "        dense_vec = self.embedding_function([query])[0]\n",
    "        if self.use_sparse is True:\n",
    "            sparse_vec = self.sparse_embedding_function.encode_queries([query])[\n",
    "                \"sparse\"\n",
    "            ][[0]]\n",
    "\n",
    "        req_list = []\n",
    "        if self.use_reranker:\n",
    "            k = k * 10\n",
    "        if self.use_sparse is True:\n",
    "            req_list = []\n",
    "            dense_search_param = {\n",
    "                \"data\": [dense_vec],\n",
    "                \"anns_field\": \"dense_vector\",\n",
    "                \"param\": {\"metric_type\": \"IP\"},\n",
    "                \"limit\": k * 2,\n",
    "            }\n",
    "            dense_req = AnnSearchRequest(**dense_search_param)\n",
    "            req_list.append(dense_req)\n",
    "\n",
    "            sparse_search_param = {\n",
    "                \"data\": [sparse_vec],\n",
    "                \"anns_field\": \"sparse_vector\",\n",
    "                \"param\": {\"metric_type\": \"IP\"},\n",
    "                \"limit\": k * 2,\n",
    "            }\n",
    "            sparse_req = AnnSearchRequest(**sparse_search_param)\n",
    "\n",
    "            req_list.append(sparse_req)\n",
    "\n",
    "            docs = self.client.hybrid_search(\n",
    "                self.collection_name,\n",
    "                req_list,\n",
    "                RRFRanker(),\n",
    "                k,\n",
    "                output_fields=[\n",
    "                    \"content\",\n",
    "                    \"original_uuid\",\n",
    "                    \"doc_id\",\n",
    "                    \"chunk_id\",\n",
    "                    \"original_index\",\n",
    "                    \"context\",\n",
    "                ],\n",
    "            )\n",
    "        else:\n",
    "            docs = self.client.search(\n",
    "                self.collection_name,\n",
    "                data=[dense_vec],\n",
    "                anns_field=\"dense_vector\",\n",
    "                limit=k,\n",
    "                output_fields=[\n",
    "                    \"content\",\n",
    "                    \"original_uuid\",\n",
    "                    \"doc_id\",\n",
    "                    \"chunk_id\",\n",
    "                    \"original_index\",\n",
    "                    \"context\",\n",
    "                ],\n",
    "            )\n",
    "        if self.use_reranker and self.use_contextualize_embedding:\n",
    "            reranked_texts = []\n",
    "            reranked_docs = []\n",
    "            for i in range(k):\n",
    "                if self.use_contextualize_embedding:\n",
    "                    reranked_texts.append(\n",
    "                        f\"{docs[0][i]['entity']['content']}\\n\\n{docs[0][i]['entity']['context']}\"\n",
    "                    )\n",
    "                else:\n",
    "                    reranked_texts.append(f\"{docs[0][i]['entity']['content']}\")\n",
    "            results = self.rerank_function(query, reranked_texts)\n",
    "            for result in results:\n",
    "                reranked_docs.append(docs[0][result.index])\n",
    "            docs[0] = reranked_docs\n",
    "        return docs\n",
    "\n",
    "\n",
    "def evaluate_retrieval(\n",
    "    queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20\n",
    ") -> Dict[str, float]:\n",
    "    total_score = 0\n",
    "    total_queries = len(queries)\n",
    "    for query_item in tqdm(queries, desc=\"Evaluating retrieval\"):\n",
    "        query = query_item[\"query\"]\n",
    "        golden_chunk_uuids = query_item[\"golden_chunk_uuids\"]\n",
    "\n",
    "        # Find all golden chunk contents\n",
    "        golden_contents = []\n",
    "        for doc_uuid, chunk_index in golden_chunk_uuids:\n",
    "            golden_doc = next(\n",
    "                (\n",
    "                    doc\n",
    "                    for doc in query_item[\"golden_documents\"]\n",
    "                    if doc[\"uuid\"] == doc_uuid\n",
    "                ),\n",
    "                None,\n",
    "            )\n",
    "            if not golden_doc:\n",
    "                print(f\"Warning: Golden document not found for UUID {doc_uuid}\")\n",
    "                continue\n",
    "\n",
    "            golden_chunk = next(\n",
    "                (\n",
    "                    chunk\n",
    "                    for chunk in golden_doc[\"chunks\"]\n",
    "                    if chunk[\"index\"] == chunk_index\n",
    "                ),\n",
    "                None,\n",
    "            )\n",
    "            if not golden_chunk:\n",
    "                print(\n",
    "                    f\"Warning: Golden chunk not found for index {chunk_index} in document {doc_uuid}\"\n",
    "                )\n",
    "                continue\n",
    "\n",
    "            golden_contents.append(golden_chunk[\"content\"].strip())\n",
    "\n",
    "        if not golden_contents:\n",
    "            print(f\"Warning: No golden contents found for query: {query}\")\n",
    "            continue\n",
    "\n",
    "        retrieved_docs = retrieval_function(query, db, k=k)\n",
    "\n",
    "        # Count how many golden chunks are in the top k retrieved documents\n",
    "        chunks_found = 0\n",
    "        for golden_content in golden_contents:\n",
    "            for doc in retrieved_docs[0][:k]:\n",
    "                retrieved_content = doc[\"entity\"][\"content\"].strip()\n",
    "                if retrieved_content == golden_content:\n",
    "                    chunks_found += 1\n",
    "                    break\n",
    "\n",
    "        query_score = chunks_found / len(golden_contents)\n",
    "        total_score += query_score\n",
    "\n",
    "    average_score = total_score / total_queries\n",
    "    pass_at_n = average_score * 100\n",
    "    return {\n",
    "        \"pass_at_n\": pass_at_n,\n",
    "        \"average_score\": average_score,\n",
    "        \"total_queries\": total_queries,\n",
    "    }\n",
    "\n",
    "\n",
    "def retrieve_base(query: str, db, k: int = 20) -> List[Dict[str, Any]]:\n",
    "    return db.search(query, k=k)\n",
    "\n",
    "\n",
    "def load_jsonl(file_path: str) -> List[Dict[str, Any]]:\n",
    "    \"\"\"Load JSONL file and return a list of dictionaries.\"\"\"\n",
    "    with open(file_path, \"r\") as file:\n",
    "        return [json.loads(line) for line in file]\n",
    "\n",
    "\n",
    "def evaluate_db(db, original_jsonl_path: str, k):\n",
    "    # Load the original JSONL data for queries and ground truth\n",
    "    original_data = load_jsonl(original_jsonl_path)\n",
    "\n",
    "    # Evaluate retrieval\n",
    "    results = evaluate_retrieval(original_data, retrieve_base, db, k)\n",
    "    print(f\"Pass@{k}: {results['pass_at_n']:.2f}%\")\n",
    "    print(f\"Total Score: {results['average_score']}\")\n",
    "    print(f\"Total queries: {results['total_queries']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you need to initialize these models for the following experiments. You can easily switch to other models using the PyMilvus model library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa37fddf0b9149e0bc6d085a0c519045",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dense_ef = VoyageEmbeddingFunction(api_key=\"your-voyage-api-key\", model_name=\"voyage-2\")\n",
    "sparse_ef = BGEM3EmbeddingFunction()\n",
    "cohere_rf = CohereRerankFunction(api_key=\"your-cohere-api-key\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"codebase_chunks.json\"\n",
    "with open(path, \"r\") as f:\n",
    "    dataset = json.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment I: Standard Retrieval\n",
    "Standard retrieval uses only dense embeddings to retrieve related documents. In this experiment, we will use Pass@5 to reproduce the results from the original repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "standard_retriever = MilvusContextualRetriever(\n",
    "    uri=\"standard.db\", collection_name=\"standard\", dense_embedding_function=dense_ef\n",
    ")\n",
    "\n",
    "standard_retriever.build_collection()\n",
    "for doc in dataset:\n",
    "    doc_content = doc[\"content\"]\n",
    "    for chunk in doc[\"chunks\"]:\n",
    "        metadata = {\n",
    "            \"doc_id\": doc[\"doc_id\"],\n",
    "            \"original_uuid\": doc[\"original_uuid\"],\n",
    "            \"chunk_id\": chunk[\"chunk_id\"],\n",
    "            \"original_index\": chunk[\"original_index\"],\n",
    "            \"content\": chunk[\"content\"],\n",
    "        }\n",
    "        chunk_content = chunk[\"content\"]\n",
    "        standard_retriever.insert_data(chunk_content, metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating retrieval: 100%|██████████| 248/248 [01:29<00:00,  2.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pass@5: 80.92%\n",
      "Total Score: 0.8091877880184332\n",
      "Total queries: 248\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "evaluate_db(standard_retriever, \"evaluation_set.jsonl\", 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment II: Hybrid Retrieval\n",
    "Now that we've obtained promising results with the Voyage embedding, we will move on to performing hybrid retrieval using the BGE-M3 model which generates powerful sparse embeddings. The results from dense retrieval and sparse retrieval will be combined using the Reciprocal Rank Fusion (RRF) method to produce a hybrid result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hybrid_retriever = MilvusContextualRetriever(\n",
    "    uri=\"hybrid.db\",\n",
    "    collection_name=\"hybrid\",\n",
    "    dense_embedding_function=dense_ef,\n",
    "    use_sparse=True,\n",
    "    sparse_embedding_function=sparse_ef,\n",
    ")\n",
    "\n",
    "hybrid_retriever.build_collection()\n",
    "for doc in dataset:\n",
    "    doc_content = doc[\"content\"]\n",
    "    for chunk in doc[\"chunks\"]:\n",
    "        metadata = {\n",
    "            \"doc_id\": doc[\"doc_id\"],\n",
    "            \"original_uuid\": doc[\"original_uuid\"],\n",
    "            \"chunk_id\": chunk[\"chunk_id\"],\n",
    "            \"original_index\": chunk[\"original_index\"],\n",
    "            \"content\": chunk[\"content\"],\n",
    "        }\n",
    "        chunk_content = chunk[\"content\"]\n",
    "        hybrid_retriever.insert_data(chunk_content, metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating retrieval: 100%|██████████| 248/248 [02:09<00:00,  1.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pass@5: 84.69%\n",
      "Total Score: 0.8469182027649771\n",
      "Total queries: 248\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "evaluate_db(hybrid_retriever, \"evaluation_set.jsonl\", 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment III: Contextual Retrieval\n",
    "Hybrid retrieval shows an improvement, but the results can be further enhanced by applying a contextual retrieval method. To achieve this, we will use Anthropic's language model to prepend the context from whole document for each chunk. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "anthropic_client = anthropic.Anthropic(\n",
    "    api_key=\"your-anthropic-api-key\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "contextual_retriever = MilvusContextualRetriever(\n",
    "    uri=\"contextual.db\",\n",
    "    collection_name=\"contextual\",\n",
    "    dense_embedding_function=dense_ef,\n",
    "    use_sparse=True,\n",
    "    sparse_embedding_function=sparse_ef,\n",
    "    use_contextualize_embedding=True,\n",
    "    anthropic_client=anthropic_client,\n",
    ")\n",
    "\n",
    "contextual_retriever.build_collection()\n",
    "for doc in dataset:\n",
    "    doc_content = doc[\"content\"]\n",
    "    for chunk in doc[\"chunks\"]:\n",
    "        metadata = {\n",
    "            \"doc_id\": doc[\"doc_id\"],\n",
    "            \"original_uuid\": doc[\"original_uuid\"],\n",
    "            \"chunk_id\": chunk[\"chunk_id\"],\n",
    "            \"original_index\": chunk[\"original_index\"],\n",
    "            \"content\": chunk[\"content\"],\n",
    "        }\n",
    "        chunk_content = chunk[\"content\"]\n",
    "        contextual_retriever.insert_contextualized_data(\n",
    "            doc_content, chunk_content, metadata\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Evaluating retrieval: 100%|██████████| 248/248 [01:55<00:00,  2.15it/s]\n",
      "Pass@5: 87.14%\n",
      "Total Score: 0.8713517665130568\n",
      "Total queries: 248 \n"
     ]
    }
   ],
   "source": [
    "evaluate_db(contextual_retriever, \"evaluation_set.jsonl\", 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment IV: Contextual Retrieval with Reranker\n",
    "The results can be further improved by adding a Cohere reranker. Without initializing a new retriever with reranker separately, we can simply configure the existing retriever to use the reranker for enhanced performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "contextual_retriever.use_reranker = True\n",
    "contextual_retriever.rerank_function = cohere_rf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating retrieval: 100%|██████████| 248/248 [02:02<00:00,  2.00it/s]\n",
      "Pass@5: 90.91%\n",
      "Total Score: 0.9090821812596005\n",
      "Total queries: 248\n"
     ]
    }
   ],
   "source": [
    "evaluate_db(contextual_retriever, \"evaluation_set.jsonl\", 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have demonstrated several methods to improve retrieval performance. With more ad-hoc design tailored to the scenario, contextual retrieval shows significant potential to preprocess documents at a low cost, leading to a better RAG system."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
